import logging

import numpy as np
import pandas as pd
from parselmouth.praat import call
from scipy import stats
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

from Voicelab.pipeline.Node import Node
from Voicelab.toolkits.Voicelab.MeasureFormantNode import MeasureFormantNode
from Voicelab.toolkits.Voicelab.MeasureFormantPositionsNode import MeasureFormantPositionsNode
from Voicelab.toolkits.Voicelab.VoicelabNode import VoicelabNode


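###################################################################################################
# MEASURE VOCAL TRACT ESTIMATES NODE
# WARIO pipeline node that estimates vocal-tract measures of a voice from its formant values.
###################################################################################################
# ARGUMENTS
# 'voice'                    : sound generated by parselmouth (Praat)
# 'F1 Mean' .. 'F4 Mean'     : formant means computed upstream (MeasureFormantNode)
# 'F1 Median' .. 'F4 Median' : formant medians computed upstream
# 'Pitch'                    : pitch computed upstream (MeasurePitchNode)
###################################################################################################
# RETURNS
# Per file: formant_dispersion, average_formant, geometric_mean, fitch_vtl, delta_f, vtl_delta_f
# Added once for the whole batch (in end): PCA_means, PCA_medians, Formant Position
###################################################################################################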


class MeasureVocalTractEstimatesNode(VoicelabNode):
    """WARIO node that estimates vocal-tract measures from formant statistics.

    ``process`` runs once per voice file and computes the per-file measures:
    formant dispersion, average formant, geometric mean, Fitch VTL, Reby
    delta F and the VTL derived from delta F. It also records each file's
    formant means and medians in ``self.state``. ``end`` uses those records to
    compute the measures that need the whole batch at once (formant PCA and
    formant position).
    """

    # Speed-of-sound constant used by the VTL formulas. Presumably cm/s, which
    # would make the VTL estimates come out in cm -- TODO confirm.
    _SPEED_OF_SOUND = 35000

    def __init__(self, *args, **kwargs):
        """Set the default options and the per-batch accumulators.

        Args:
            *args: Forwarded unchanged to ``VoicelabNode.__init__``.
            **kwargs: Forwarded unchanged to ``VoicelabNode.__init__``.
        """
        super().__init__(*args, **kwargs)

        # User-facing options; every measure is enabled by default.
        self.args = {
            "Measure Formant PCA": True,
            "Measure Formant Positions": True,
            "Measure Formant Dispersion": True,
            "Measure Average Formant": True,
            "Measure Geometric Mean": True,
            "Measure Fitch VTL": True,
            "Measure Delta F": True,
            "Measure VTL Delta F": True,
        }
        # Per-file formant statistics. ``process`` appends to these in
        # processing order and ``end`` consumes them.
        # NOTE(review): they are only initialised here; if one node instance
        # is reused for several runs, confirm the pipeline resets them.
        self.state = {
            "f1 means": [],
            "f2 means": [],
            "f3 means": [],
            "f4 means": [],
            "f1 medians": [],
            "f2 medians": [],
            "f3 medians": [],
            "f4 medians": [],
        }

    ###############################################################################################
    # process: WARIO hook called once for each voice file.
    ###############################################################################################
    def process(self):
        """Compute the per-file vocal-tract estimates.

        Reads ``voice``, ``F1 Mean`` .. ``F4 Mean``, ``F1 Median`` ..
        ``F4 Median`` and ``Pitch`` from ``self.args``. It appends the means
        and medians to ``self.state`` for the batch-level measures in ``end``.

        Returns:
            dict: ``formant_dispersion``, ``average_formant``,
            ``geometric_mean``, ``fitch_vtl``, ``delta_f`` and
            ``vtl_delta_f``. A measure whose option is switched off is
            reported as ``"Not Selected"``. If any step raises, every measure
            is reported as ``0`` and a warning is logged.
        """
        try:
            # NOTE(review): ``voice`` and ``pitch`` are not used by any formula
            # below. They are still read so that a missing upstream input
            # keeps taking the fallback path at the bottom.
            voice = self.args["voice"]

            # F1 - F4 means and medians are calculated in the MeasureFormantNode.
            f1 = self.args["F1 Mean"]
            f2 = self.args["F2 Mean"]
            f3 = self.args["F3 Mean"]
            f4 = self.args["F4 Mean"]

            f1_median = self.args["F1 Median"]
            f2_median = self.args["F2 Median"]
            f3_median = self.args["F3 Median"]
            f4_median = self.args["F4 Median"]

            # Record this file's formants for the batch-level measures in ``end``.
            means = (f1, f2, f3, f4)
            medians = (f1_median, f2_median, f3_median, f4_median)
            for number, (mean, median) in enumerate(zip(means, medians), start=1):
                self.state[f"f{number} means"].append(mean)
                self.state[f"f{number} medians"].append(median)

            # Pitch is calculated in the MeasurePitchNode and passed in here.
            pitch = self.args["Pitch"]

            if self.args["Measure Formant Dispersion"]:
                formant_dispersion = (f4 - f1) / 3
            else:
                formant_dispersion = "Not Selected"

            if self.args["Measure Average Formant"]:
                average_formant = (f1 + f2 + f3 + f4) / 4
            else:
                average_formant = "Not Selected"

            if self.args["Measure Geometric Mean"]:
                geometric_mean = (f1 * f2 * f3 * f4) ** 0.25
            else:
                geometric_mean = "Not Selected"

            if self.args["Measure Fitch VTL"]:
                # Mean of the per-formant estimates (2i - 1) * c / (4 * Fi).
                c = self._SPEED_OF_SOUND
                fitch_vtl = (
                    (1 * (c / (4 * f1)))
                    + (3 * (c / (4 * f2)))
                    + (5 * (c / (4 * f3)))
                    + (7 * (c / (4 * f4)))
                ) / 4
            else:
                fitch_vtl = "Not Selected"

            # Reby method: delta F is the least-squares slope, through the
            # origin, of formant frequency Fi against (2i - 1) / 2. The VTL
            # estimate is derived from it, so delta F is computed whenever
            # either option is enabled. It is only *reported* under
            # "Measure Delta F". Previously, disabling Delta F made the VTL
            # formula divide by the string "Not Selected", which raised and
            # zeroed every measure.
            want_delta_f = self.args["Measure Delta F"]
            want_vtl_delta_f = self.args["Measure VTL Delta F"]
            delta_f_value = None
            if want_delta_f or want_vtl_delta_f:
                xysum = (0.5 * f1) + (1.5 * f2) + (2.5 * f3) + (3.5 * f4)
                xsquaredsum = (0.5 ** 2) + (1.5 ** 2) + (2.5 ** 2) + (3.5 ** 2)
                delta_f_value = xysum / xsquaredsum

            delta_f = delta_f_value if want_delta_f else "Not Selected"

            if want_vtl_delta_f:
                vtl_delta_f = self._SPEED_OF_SOUND / (2 * delta_f_value)
            else:
                vtl_delta_f = "Not Selected"

            return {
                "formant_dispersion": formant_dispersion,
                "average_formant": average_formant,
                "geometric_mean": geometric_mean,
                "fitch_vtl": fitch_vtl,
                "delta_f": delta_f,
                "vtl_delta_f": vtl_delta_f,
            }
        except Exception:
            # Best effort: a file whose inputs are missing or unusable (e.g. a
            # zero formant dividing in the VTL formulas) still gets a row, with
            # every measure set to 0. Log it so the failure is not silent.
            logging.getLogger(__name__).warning(
                "Vocal tract estimates failed; reporting 0 for every measure",
                exc_info=True,
            )
            return {
                "formant_dispersion": 0,
                "average_formant": 0,
                "geometric_mean": 0,
                "fitch_vtl": 0,
                "delta_f": 0,
                "vtl_delta_f": 0,
            }

    # Formant PCA and formant position are computed across all files, so they
    # run once in this end-of-batch hook rather than per file in ``process``.
    def end(self, results):
        """Add the batch-level formant measures to each file's results.

        The calculations are borrowed from MeasureFormantNode and
        MeasureFormantPositionsNode rather than duplicated here.

        Args:
            results: Per-file results, indexed like ``self.state``;
                ``results[i][self]`` is this node's output dict for file
                ``i`` and is updated in place.

        Returns:
            The same ``results`` object, with ``PCA_means``/``PCA_medians``
            and/or ``Formant Position`` added when those options are enabled.
        """
        calc_formant_pca = MeasureFormantNode().formant_pca
        calc_formant_position = MeasureFormantPositionsNode().calculate_formant_position

        formant_mean_lists = [
            self.state["f1 means"],
            self.state["f2 means"],
            self.state["f3 means"],
            self.state["f4 means"],
        ]
        formant_median_lists = [
            self.state["f1 medians"],
            self.state["f2 medians"],
            self.state["f3 medians"],
            self.state["f4 medians"],
        ]

        # NOTE(review): entry i of the state lists is assumed to belong to
        # results[i]. A file whose ``process`` failed before recording its
        # formants has no state entry, which would shift this alignment.

        if self.args["Measure Formant PCA"]:
            principal_components_means, principal_components_medians = calc_formant_pca(
                formant_mean_lists, formant_median_lists
            )

            # Column 0 of each returned 2-D array (presumably the first
            # principal component) is the per-file score.
            for i, result in enumerate(results):
                result[self]["PCA_means"] = float(principal_components_means[i, 0])
                result[self]["PCA_medians"] = float(principal_components_medians[i, 0])

        if self.args["Measure Formant Positions"]:
            formant_positions = calc_formant_position(
                formant_mean_lists, formant_median_lists
            )

            # A string result is copied to every file unchanged; otherwise
            # there is one position per file.
            if isinstance(formant_positions, str):
                for result in results:
                    result[self]["Formant Position"] = formant_positions
            else:
                for i, result in enumerate(results):
                    result[self]["Formant Position"] = float(formant_positions[i])

        return results
